package loveqq.jdbc.log.proxy;

import loveqq.jdbc.log.config.AgentConfig;
import loveqq.jdbc.log.replace.DruidSqlReplace;
import loveqq.jdbc.log.replace.SqlReplace;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/**
 * Jdbc代理
 */
public class JdbcProxy {
    /**
     * PreparedStatement存储SQL的字段
     */
    public static final String SQL_PROXY_FIELD = "sql$agent";
    /**
     * PreparedStatement存储SQL的字段
     */
    public static final String PARAMS_PROXY_FIELD = "params$agent";
    /**
     * 需要代理的执行SQL方法
     */
    private static final Set<String> EXECUTE_METHODS;
    /**
     * 设置参数的方法
     */
    private static final Set<String> SET_PARAMETER_METHODS;
    /**
     * SQL动态参数替换者
     */
    private static final SqlReplace sqlReplace;

    static {
        EXECUTE_METHODS = Stream.of("execute", "executeQuery", "executeUpdate").collect(Collectors.toSet());

        SET_PARAMETER_METHODS = Stream.of(PreparedStatement.class.getDeclaredMethods())
                .filter(method -> method.getName().startsWith("set"))
                .filter(method -> method.getParameterCount() > 1)
                .map(Method::getName)
                .collect(Collectors.toSet());

        sqlReplace = new DruidSqlReplace();
    }

    /**
     * 核心代理方法
     */
    public static Connection proxy(Connection connection) {
        if (connection == null) {
            return null;
        }
        ClassLoader classLoader = JdbcProxy.class.getClassLoader();
        // 代理Connection
        return (Connection) Proxy.newProxyInstance(classLoader, new Class[]{Connection.class},
                (proxy, connMethod, connArgs) -> {
                    // 这里是Connection对象的方法
                    String methodName = connMethod.getName();
                    // 执行Connection对象的方法获取结果，主要是获取statement对象
                    Object result = connMethod.invoke(connection, connArgs);
                    /*
                        不同statement不同代理，主要是执行SQL的方法，执行前打印SQL+参数
                     */
                    if ("prepareStatement".equals(methodName)) {
                        // 代理prepareStatement(String sql, ...)方法 (第一个参数是SQL)
                        String sql = proxySetSql(connArgs, result);

                        // 代理PreparedStatement对象
                        return Proxy.newProxyInstance(classLoader, new Class[]{PreparedStatement.class},
                                (target, stmtMethod, stmtArgs) -> {
                                    PreparedStatement statement = (PreparedStatement) result;
                                    // 设置参数方法，如：setXxx(参数位置,参数值)
                                    if (SET_PARAMETER_METHODS.contains(stmtMethod.getName())) {
                                        Map<Integer, Object> paramMap = proxySetParam(stmtArgs, statement);
                                        try {
                                            return stmtMethod.invoke(statement, stmtArgs);
                                        } catch (Exception e) {
                                            // 出错后也输出日志
                                            log(sql, paramMap, getDbType(statement));
                                            throw e;
                                        }
                                    }
                                    // 代理执行SQL的方法
                                    return (stmtArgs == null || stmtArgs.length == 0)
                                            // 无参数
                                            ? proxyPreparedStatement(statement, stmtMethod, stmtArgs)
                                            // 带SQL参数
                                            : proxyStatement(statement, stmtMethod, stmtArgs);
                                });

                    } else if ("createStatement".equals(methodName)) {
                        // 代理createStatement方法
                        // 代理Statement对象
                        return Proxy.newProxyInstance(classLoader, new Class[]{Statement.class},
                                (target, stmtMethod, stmtArgs) -> {
                                    Statement statement = (Statement) result;
                                    return proxyStatement(statement, stmtMethod, stmtArgs);
                                });
                    }
                    return result;
                });
    }

    /**
     * 代理Connection对象的preparedStatement方法获取SQL
     */
    private static String proxySetSql(Object[] args, Object result)
            throws NoSuchFieldException, IllegalAccessException {
        String sql = (String) args[0];
        // 获取代理存储SQL的字段，将SQL设置进去
        Field sqlField = result.getClass().getDeclaredField(SQL_PROXY_FIELD);
        sqlField.setAccessible(true);
        sqlField.set(result, sql);

        return sql;
    }

    /**
     * 代理PreparedStatement设置参数的方法
     */
    private static Map<Integer, Object> proxySetParam(Object[] args, PreparedStatement statement)
            throws NoSuchFieldException, IllegalAccessException {
        int pos = (int) args[0];
        Object value = args[1];
        // 获取代理存储params的字段，将参数设置进去
        Field paramsField = statement.getClass().getDeclaredField(PARAMS_PROXY_FIELD);
        paramsField.setAccessible(true);
        // 默认是初始化的 Map<Integer, Object>
        Map<Integer, Object> paramMap = (Map<Integer, Object>) paramsField.get(statement);
        if (paramMap == null) {
            // 这里用Map是因为JDBC：设置参数可以按占位符设置，用List顺序可能会乱
            paramMap = new HashMap<>();
            paramsField.set(statement, paramMap);
        }
        paramMap.put(pos, value);

        return paramMap;
    }

    /**
     * 代理Statement对象
     */
    private static Object proxyStatement(Statement statement, Method method, Object[] args)
            throws Exception {
        // 只代理EXECUTE_METHODS的方法
        if (EXECUTE_METHODS.contains(method.getName())) {
            // 获取SQL和参数，日志输出
            // Statement的执行SQL方法第一个参数是SQL
            String sql = (String) args[0];
            log(sql, Collections.emptyMap(), getDbType(statement));
        }
        // 执行方法
        return method.invoke(statement, args);
    }

    /**
     * 代理PreparedStatement对象
     */
    private static Object proxyPreparedStatement(PreparedStatement statement, Method method, Object[] args)
            throws Exception {
        // 只代理EXECUTE_METHODS的方法
        if (EXECUTE_METHODS.contains(method.getName())) {
            // 获取SQL
            Field sqlField = statement.getClass().getDeclaredField(SQL_PROXY_FIELD);
            sqlField.setAccessible(true);
            String sql = (String) sqlField.get(statement);
            // 获取参数
            Field paramsField = statement.getClass().getDeclaredField(PARAMS_PROXY_FIELD);
            paramsField.setAccessible(true);
            Map<Integer, Object> paramMap = (Map<Integer, Object>) paramsField.get(statement);
            // 日志输出
            log(sql, paramMap, getDbType(statement));
        }
        // 执行方法
        return method.invoke(statement, args);
    }

    /**
     * 日志输入：SQL+参数
     */
    public static void log(String sql, Map<Integer, Object> paramMap, String dbType) {
        // SQL
        sql = (sql != null) ? sql : "";
        // 判断SQL是否属于排除，排除就不打印
        if (AgentConfig.INSTANCE.isExclude(sql)) {
            return;
        }
        // 参数
        String paramsStr;
        // 参数值填充SQL
        String fillSql;
        // 参数最大位置
        int maxParamIndex = Optional.ofNullable(paramMap)
                .filter(map -> !map.isEmpty())
                .flatMap(map -> map.keySet().stream().max(Integer::compareTo))
                .orElse(-1);

        // SQL动态参数替换
        try {
            // 填充替换动态参数：占位符?
            fillSql = sqlReplace.replace(sql, paramMap, dbType);
            // 按顺序拼接参数值
            paramsStr = IntStream.range(0, maxParamIndex)
                    .mapToObj(i -> {
                        Object param = paramMap.get(i + 1);
                        // 字符串用''包装
                        return (param instanceof String) ? ("'" + param + "'") : String.valueOf(param);
                    })
                    .collect(Collectors.joining(", ", "[", "]"));

        } catch (Exception e) {
            // 参数设置或数目错误
            fillSql = "| 参数设置不正确, 无法生成参数填充SQL";
            // 输出参数位置+参数值(位置找不到参数，显示空白"")
            StringBuilder builder = new StringBuilder();
            for (int i = 1; i <= maxParamIndex; i++) {
                builder.append("\n|         ")
                        .append(i)
                        .append(" = ")
                        .append(paramMap.getOrDefault(i, ""));
            }
            paramsStr = (builder.length() > 0) ? builder.toString() : "[]";
        }

        // 日志
        String log =
                "+-----+------------------------+\n" +
                "SQL   : " + sql + "\n" +
                "Param : " + paramsStr + "\n" +
                "..................\n" +
                "FSQL  : "+ fillSql + "\n" +
                "+------------------------------+\n";
        // 日志输出
        System.out.println(log);

    }

    /**
     * 根据 Statement 获取 Connection 获取数据库类型
     */
    private static String getDbType(Statement statement) throws SQLException {
        return statement.getConnection().getMetaData().getDatabaseProductName().toLowerCase();
    }
}
